#model4.py
import torchvision.models as models
import torch.nn as nn

def get_model(num_classes, *, weights="IMAGENET1K_V1"):
    """Build a ResNet-50 whose classification head outputs ``num_classes`` logits.

    The backbone is created with ``torchvision.models.resnet50`` using the
    requested weights. Its final fully connected layer (``model.fc``) is then
    replaced by a new ``nn.Linear`` with ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes; must be at least 1.
        weights: Backbone weights, in one of three forms:
            a ``torchvision.models.ResNet50_Weights`` member;
            the name of such a member as a string (the default,
            ``"IMAGENET1K_V1"``, matches the previously hard-coded weights);
            or ``None`` for a randomly initialised backbone.

    Returns:
        The ResNet-50 model with its replaced classification head.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
        KeyError: If ``weights`` is a string that names no
            ``ResNet50_Weights`` member.
    """
    # Validate first so a bad argument fails fast, before the backbone
    # (and possibly its pretrained weights) is loaded.
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    # The default is a plain string so that defining this function never
    # evaluates torchvision's weight enum; resolve it here, at call time.
    if isinstance(weights, str):
        weights = models.ResNet50_Weights[weights]

    model = models.resnet50(weights=weights)
    # Replace the original head with one sized for this task; the new layer
    # gets nn.Linear's default (random) initialisation.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model